# Filter the genes (rows) of the raw count matrix stored in a Seurat object.
#
# Background: how to filter a genes x cells expression matrix is understood
# differently by different companies/people. Some filter on the differential
# genes reported by FindAllMarkers; others only apply a simple filter that
# removes genes whose expression is all zero and genes expressed in only a
# few cells (1%). The source code of SCENIC::geneFiltering additionally
# requires the genes to be present in the feather database files. How to
# filter genes remains case by case; the view taken here is to shrink the
# matrix as much as possible, to speed up the run, as long as the results
# are not affected.
#
# Arguments:
#   project           Seurat object; its "counts" slot is read.
#   minCountsPerGene  A gene is kept only if its total count over all cells
#                     is strictly greater than this value (default 1).
#   minPercent        A gene is kept only if it is detected (count > 0) in
#                     strictly more than this fraction of the cells
#                     (default 0.01).
#
# Returns a dense matrix holding the retained genes (rows, original order)
# for all cells (columns). The result is always a matrix, even when zero or
# one gene passes the filter.
filter_matrix <- function(project = NULL,
                          minCountsPerGene = 1,
                          minPercent = 0.01) {
  if (is.null(project)) {
    stop("`project` must be a Seurat object, not NULL.", call. = FALSE)
  }

  exprMatix <- as.matrix(Seurat::GetAssayData(object = project, slot = "counts"))

  nCountsPerGene <- rowSums(exprMatix, na.rm = TRUE)
  nCellsPerGene <- rowSums(exprMatix > 0, na.rm = TRUE)
  minSamples <- ncol(exprMatix) * minPercent

  # Both thresholds are strict (">"), matching the original two-step filter.
  keep <- nCountsPerGene > minCountsPerGene & nCellsPerGene > minSamples

  # drop = FALSE keeps a matrix when only a single gene survives.
  exprMatix[which(keep), , drop = FALSE]
}

# Build a vector of distinct qualitative colors for plots.
#
# The palette concatenates the RColorBrewer sets Paired, Set1, Set2, Set3,
# Pastel2, Pastel1 and Accent, removes duplicated colors, and then drops the
# entries at positions 17 and 23 of the de-duplicated vector.
#
# Arguments:
#   colors_len  Number of colors wanted. NULL (the default), or omitting the
#               argument, returns the whole palette.
#
# Returns a character vector of colors of length `colors_len`. When more
# colors are requested than the palette holds, the palette is recycled (with
# a warning) instead of padding the result with NA.
get_colors <- function(colors_len = NULL) {
  cols <- c(
    RColorBrewer::brewer.pal(12, "Paired"),
    RColorBrewer::brewer.pal(9, "Set1"),
    RColorBrewer::brewer.pal(8, "Set2"),
    RColorBrewer::brewer.pal(12, "Set3"),
    RColorBrewer::brewer.pal(8, "Pastel2"),
    RColorBrewer::brewer.pal(9, "Pastel1"),
    RColorBrewer::brewer.pal(8, "Accent")
  )
  pal <- unique(cols)[-c(17, 23)]

  # is.null() (rather than missing()) also covers an explicit
  # `colors_len = NULL`, e.g. when a caller forwards its own NULL default.
  if (is.null(colors_len)) {
    return(pal)
  }

  if (!is.numeric(colors_len) || length(colors_len) != 1L ||
      is.na(colors_len) || colors_len < 0) {
    stop("`colors_len` must be NULL or a single non-negative number.",
         call. = FALSE)
  }

  colors_len <- as.integer(colors_len)
  if (colors_len > length(pal)) {
    warning("Requested ", colors_len, " colors but only ", length(pal),
            " are available; colors are recycled.", call. = FALSE)
  }

  # rep_len() returns character(0) for 0 and recycles beyond the palette size.
  rep_len(pal, colors_len)
}

# Save a plot object twice: as a PDF at `file` and as a 300 dpi PNG next to it.
#
# Arguments:
#   obj     Object whose print() method draws the plot.
#   file    Path of the PDF to write. The PNG path is derived from it by
#           replacing a trailing ".pdf" with ".png"; if `file` has no ".pdf"
#           suffix, ".png" is appended so the PNG never overwrites the PDF.
#   width   Width of both outputs, in inches.
#   height  Height of both outputs, in inches.
#
# Returns (invisibly) the device that is current after both files are closed.
sava_pdf <- function(obj = NULL, file = NULL,
                     width = 6, height = 5) {
  if (!is.character(file) || length(file) != 1L || is.na(file) || !nzchar(file)) {
    stop("`file` must be a single, non-empty file path.", call. = FALSE)
  }

  # Only the extension is rewritten; a "pdf" occurring elsewhere in the path
  # (e.g. in a directory name) must stay untouched.
  pdf_suffix <- "\\.pdf$"
  file_png <- if (grepl(pdf_suffix, file, ignore.case = TRUE)) {
    sub(pdf_suffix, ".png", file, ignore.case = TRUE)
  } else {
    paste0(file, ".png")
  }

  # Open a device, draw, and always close the device again, even when
  # print() fails, so no graphics device is leaked.
  render <- function(open_device) {
    open_device()
    on.exit(grDevices::dev.off(), add = TRUE)
    print(obj)
    invisible(NULL)
  }

  render(function() {
    grDevices::pdf(file = file, width = width, height = height)
  })
  render(function() {
    grDevices::png(filename = file_png, width = width, height = height,
                   res = 300, units = "in")
  })

  invisible(grDevices::dev.cur())
}

# Color the tSNE and UMAP embeddings by the AUC score of transcription
# factors (regulons). Without `TF_list`, all transcription factors are drawn,
# spread over 10 figures per reduction.
#
# Arguments:
#   project     Seurat object; the reductions 'tsne' and 'umap' are plotted.
#   auc_matrix  AUC score matrix; it is stored as an assay named "SCENIC" in
#               the local copy of `project` (the caller's object is not
#               modified).
#   TF_list     Optional vector of feature names. When given, one figure per
#               feature is written; when NULL, all features of the SCENIC
#               assay are split into up to 10 blocks, one figure per block.
#
# Side effects: writes SCENIC_tsne_<tag>.pdf/.png and
# SCENIC_umap_<tag>.pdf/.png into the working directory via sava_pdf(), where
# <tag> is the block index or the feature name. Returns NULL invisibly.
plot_reductions <- function(project = NULL,
                            auc_matrix = NULL,
                            TF_list = NULL) {
  project[["SCENIC"]] <- Seurat::CreateAssayObject(counts = auc_matrix)
  Seurat::DefaultAssay(project) <- "SCENIC"

  if (is.null(TF_list)) {
    # rank(x) %% n yields bucket ids 0..n-1; sorting the ids turns the buckets
    # into consecutive runs of x of nearly equal size.
    chunk <- function(x, n) {
      split(x, factor(sort(rank(x) %% n)))
    }
    feature_sets <- unname(chunk(rownames(project), 10))
    file_tags <- seq_along(feature_sets)
    width <- 15
    height <- 12
  } else {
    feature_sets <- as.list(TF_list)
    file_tags <- TF_list
    width <- 8
    height <- 6
  }

  # seq_along() is safe when there is nothing to plot (1:0 would iterate).
  for (i in seq_along(feature_sets)) {
    for (reduction in c("tsne", "umap")) {
      p1 <- Seurat::FeaturePlot(object = project,
                                features = feature_sets[[i]],
                                reduction = reduction,
                                pt.size = 0.1)
      sava_pdf(obj = p1,
               file = paste0("SCENIC_", reduction, "_", file_tags[[i]], ".pdf"),
               width = width, height = height)
    }
  }
  invisible(NULL)
}

# Violin plots of transcription factor (regulon) AUC scores per group.
#
# Arguments:
#   project     Seurat object.
#   auc_matrix  AUC score matrix; it is stored as an assay named "SCENIC" in
#               the local copy of `project`.
#   group_by    Name of the metadata column used to group cells
#               (default "Cluster").
#   TF_list     Optional vector of feature names. When given, one figure per
#               feature is written; when NULL, all features of the SCENIC
#               assay are split into up to 10 blocks, one figure per block.
#
# Side effects: writes SCENIC_violinplot_<tag>.pdf/.png into the working
# directory via sava_pdf(), where <tag> is the block index or the feature
# name. Returns NULL invisibly.
plot_violinplot <- function(project = NULL,
                            auc_matrix = NULL,
                            group_by = "Cluster",
                            TF_list = NULL) {
  project[["SCENIC"]] <- Seurat::CreateAssayObject(counts = auc_matrix)
  Seurat::DefaultAssay(project) <- "SCENIC"
  col_list <- get_colors(colors_len = length(unique(project[[group_by]][, 1])))

  axis_labels <- ggplot2::labs(x = NULL, y = "AUC score")
  plot_theme <- ggplot2::theme(
    axis.line = ggplot2::element_line(colour = "black"),
    panel.grid = ggplot2::element_blank(),
    legend.position = "none",
    # margin() takes top/right/bottom/left as separate arguments and the unit
    # by name; passing c(20, 20, 20, 20) and "mm" positionally does not
    # produce a 20 mm margin.
    plot.margin = ggplot2::margin(20, 20, 20, 20, unit = "mm"),
    strip.background = ggplot2::element_rect(colour = "#f0f0f0", fill = "#f0f0f0"),
    strip.text = ggplot2::element_text(face = "bold")
  )

  if (is.null(TF_list)) {
    # rank(x) %% n yields bucket ids 0..n-1; sorting the ids turns the buckets
    # into consecutive runs of x of nearly equal size.
    chunk <- function(x, n) {
      split(x, factor(sort(rank(x) %% n)))
    }
    feature_sets <- unname(chunk(rownames(project), 10))
    file_tags <- seq_along(feature_sets)
    width <- 15
    height <- 12
  } else {
    feature_sets <- as.list(TF_list)
    file_tags <- TF_list
    width <- 8
    height <- 6
  }

  # NOTE(review): if VlnPlot returns a multi-panel patchwork, `+` presumably
  # styles only the last panel (`&` would style all of them) - confirm the
  # intended look before changing the operator.
  for (i in seq_along(feature_sets)) {
    p1 <- Seurat::VlnPlot(project,
                          cols = col_list,
                          features = feature_sets[[i]],
                          group.by = group_by,
                          pt.size = 0) +
      axis_labels +
      plot_theme
    sava_pdf(obj = p1,
             file = paste0("SCENIC_violinplot_", file_tags[[i]], ".pdf"),
             width = width, height = height)
  }
  invisible(NULL)
}

# Ridge plots of transcription factor (regulon) AUC scores per group.
#
# Arguments:
#   project     Seurat object.
#   auc_matrix  AUC score matrix; it is stored as an assay named "SCENIC" in
#               the local copy of `project`.
#   group_by    Name of the metadata column used to group cells
#               (default "Cluster").
#   TF_list     Optional vector of feature names. When given, one figure per
#               feature is written; when NULL, all features of the SCENIC
#               assay are split into up to 10 blocks, one figure per block.
#
# Side effects: writes SCENIC_ggridges_<tag>.pdf/.png into the working
# directory via sava_pdf(), where <tag> is the block index or the feature
# name. Returns NULL invisibly.
plot_ggridgesplot <- function(project = NULL,
                              auc_matrix = NULL,
                              group_by = "Cluster",
                              TF_list = NULL) {
  project[["SCENIC"]] <- Seurat::CreateAssayObject(counts = auc_matrix)
  Seurat::DefaultAssay(project) <- "SCENIC"
  col_list <- get_colors(colors_len = length(unique(project[[group_by]][, 1])))

  axis_labels <- ggplot2::labs(x = "AUC score", y = NULL)
  plot_theme <- ggplot2::theme(
    axis.line = ggplot2::element_line(colour = "black"),
    panel.grid = ggplot2::element_blank(),
    legend.position = "none",
    # margin() takes top/right/bottom/left as separate arguments and the unit
    # by name; passing c(20, 20, 20, 20) and "mm" positionally does not
    # produce a 20 mm margin.
    plot.margin = ggplot2::margin(20, 20, 20, 20, unit = "mm"),
    strip.background = ggplot2::element_rect(colour = "#f0f0f0", fill = "#f0f0f0"),
    strip.text = ggplot2::element_text(face = "bold")
  )

  if (is.null(TF_list)) {
    # rank(x) %% n yields bucket ids 0..n-1; sorting the ids turns the buckets
    # into consecutive runs of x of nearly equal size.
    chunk <- function(x, n) {
      split(x, factor(sort(rank(x) %% n)))
    }
    feature_sets <- unname(chunk(rownames(project), 10))
    file_tags <- seq_along(feature_sets)
    width <- 15
    height <- 12
  } else {
    feature_sets <- as.list(TF_list)
    file_tags <- TF_list
    width <- 8
    height <- 6
  }

  # NOTE(review): if RidgePlot returns a multi-panel patchwork, `+` presumably
  # styles only the last panel (`&` would style all of them) - confirm the
  # intended look before changing the operator.
  for (i in seq_along(feature_sets)) {
    p1 <- Seurat::RidgePlot(project,
                            cols = col_list,
                            features = feature_sets[[i]],
                            group.by = group_by) +
      axis_labels +
      plot_theme
    sava_pdf(obj = p1,
             file = paste0("SCENIC_ggridges_", file_tags[[i]], ".pdf"),
             width = width, height = height)
  }
  invisible(NULL)
}

# 绘图函数：plot_reductions() 、plot_violinplot() 和 plot_ggridgesplot 都可以接 TF_list 来单独展示想要展示的 转录因子, 若不给 TF_list 的话, 就默认全画了
# TF_list <- c("ARID3A(63g)" ,"BACH1(625g)" ,"BACH2(5g)")
# plot_reductions(project = project,auc_matrix = data,TF_list = TF_list)
# plot_violinplot(project = project,auc_matrix = data,TF_list = TF_list)
# plot_ggridgesplot(project = project,auc_matrix = data,TF_list = TF_list)

# Heatmap at single-cell level: row-wise z-scored AUC scores, one column per
# cell, annotated with the cell's group and sample.
#
# Arguments:
#   project     Seurat object; its metadata must contain the `group_by`
#               column and a column named "Sample".
#   auc_matrix  AUC score matrix (regulons x cells).
#   group_by    Name of the metadata column used for the "Group" annotation
#               (default "Cluster").
#
# Side effects: writes SCENIC_AUC_score_heatmap.pdf/.png into the working
# directory via sava_pdf(). Returns whatever sava_pdf() returns.
plot_sc_heatmap <- function(project = NULL,
                         auc_matrix = NULL,
                         group_by = "Cluster") {
  # Rows with zero (or undefined) standard deviation cannot be z-scored.
  # which() drops NA results; drop = FALSE keeps a matrix for a single row.
  row_sd <- apply(auc_matrix, 1, stats::sd)
  auc_matrix <- auc_matrix[which(row_sd != 0), , drop = FALSE]
  auc_matrix <- t(scale(t(auc_matrix), center = TRUE, scale = TRUE))

  group_meta <- project[[group_by]]
  sample_meta <- project[["Sample"]]

  # The annotation must follow the column (cell) order of the heatmap matrix.
  # When the matrix columns are named after cells of `project`, look the
  # metadata up by cell name; otherwise fall back to positional matching.
  cells <- colnames(auc_matrix)
  if (!is.null(cells) && all(cells %in% rownames(group_meta))) {
    column_df <- data.frame(Group = group_meta[cells, 1],
                            Sample = sample_meta[cells, 1])
  } else {
    column_df <- data.frame(Group = group_meta[, 1],
                            Sample = sample_meta[, 1])
  }

  # Colors are assigned over all values present in `project`, in order of
  # first appearance.
  group_values <- unique(group_meta[, 1])
  Group_color <- get_colors(colors_len = length(group_values))
  names(Group_color) <- as.character(group_values)
  sample_values <- unique(sample_meta[, 1])
  Sample_color <- get_colors(colors_len = length(sample_values))
  names(Sample_color) <- as.character(sample_values)

  # library(circlize) # the color range of the heatmap can be tuned manually
  # col_fun <- circlize::colorRamp2(c(-2,0,6),gplots::colorpanel(75, low="blue", mid="white",high="red"))
  column_annotation <- ComplexHeatmap::HeatmapAnnotation(
    df = column_df,
    col = list(Group = Group_color, Sample = Sample_color)
  )
  complexheatmap <- ComplexHeatmap::Heatmap(
    auc_matrix,
    # col = col_fun,
    col = gplots::colorpanel(75, low = "blue", mid = "white", high = "red"),
    name = "heatmap",
    heatmap_legend_param = list(legend_direction = "horizontal",
                                legend_width = grid::unit(10, "cm"),
                                title_position = "lefttop"),
    column_names_side = "top", cluster_columns = TRUE,
    cluster_rows = TRUE, show_column_names = FALSE,
    row_names_gp = grid::gpar(fontsize = 12),
    top_annotation = column_annotation
  )
  # NOTE(review): draw() is called before sava_pdf() opens its devices, so it
  # presumably also renders once on whatever device is active - confirm.
  p <- ComplexHeatmap::draw(complexheatmap, heatmap_legend_side = "bottom")

  # Use a taller page when many regulons (rows) have to be shown.
  height <- if (nrow(auc_matrix) < 100) 15 else 30
  sava_pdf(obj = p, file = "SCENIC_AUC_score_heatmap.pdf", width = 20, height = height)
}

# Heatmap at group level. The grouping can be CellType, Cluster or Sample;
# for every transcription factor (regulon) the AUC scores of the cells in a
# group are averaged, and the group means are z-scored per regulon.
#
# Arguments:
#   project     Seurat object; its metadata must contain the `group_by`
#               column.
#   auc_matrix  AUC score matrix (regulons x cells); its column names must
#               contain the cell names of `project`.
#   group_by    Name of the metadata column defining the groups
#               (default "Cluster").
#
# Side effects: writes SCENIC_AUC_group_heatmap.pdf/.png via sava_pdf() and
# the unscaled group means to the tab-separated file finalout_auc_group.xls,
# both in the working directory.
plot_group_heatmap <- function(project = NULL,
                               auc_matrix = NULL,
                               group_by = "Cluster") {
  group_values <- project[[group_by]][, 1]

  # drop = TRUE skips factor levels without cells (their mean would be NaN).
  cellsPerGroup <- split(colnames(project), group_values, drop = TRUE)
  if (length(cellsPerGroup) == 0L) {
    stop("No cells found for grouping column '", group_by, "'.", call. = FALSE)
  }

  # drop = FALSE keeps a matrix for groups consisting of a single cell.
  group_means <- lapply(cellsPerGroup, function(cells) {
    rowMeans(auc_matrix[, cells, drop = FALSE])
  })
  # Assemble explicitly so the result is a matrix even for a single regulon.
  regulonActivity_byGroup <- matrix(
    unlist(group_means, use.names = FALSE),
    ncol = length(group_means),
    dimnames = list(rownames(auc_matrix), names(group_means))
  )
  regulonActivity_byGroup_Scaled <- t(scale(t(regulonActivity_byGroup),
                                            center = TRUE, scale = TRUE))

  # Regulons with identical means in every group scale to NaN and cannot be
  # clustered; they are left out of the heatmap (but kept in the table).
  plottable <- rowSums(!is.finite(regulonActivity_byGroup_Scaled)) == 0
  if (!all(plottable)) {
    warning(sum(!plottable),
            " regulon(s) without variation across groups are omitted from the heatmap.",
            call. = FALSE)
  }
  regulonActivity_byGroup_Scaled <-
    regulonActivity_byGroup_Scaled[plottable, , drop = FALSE]
  if (nrow(regulonActivity_byGroup_Scaled) == 0L) {
    stop("No regulon varies across the groups of '", group_by,
         "'; nothing to plot.", call. = FALSE)
  }

  # The annotation must follow the column order of the heatmap matrix, which
  # is the order produced by split(), not the order of first appearance.
  group_levels <- colnames(regulonActivity_byGroup_Scaled)
  column_df <- data.frame(Group = factor(group_levels, levels = group_levels))

  # Colors are assigned in order of first appearance and matched by name.
  distinct_groups <- unique(group_values)
  Group_color <- get_colors(colors_len = length(distinct_groups))
  names(Group_color) <- as.character(distinct_groups)

  # library(circlize) # the color range of the heatmap can be tuned manually
  # col_fun <- circlize::colorRamp2(c(-2,0,6),gplots::colorpanel(75, low="blue", mid="white",high="red"))
  column_annotation <- ComplexHeatmap::HeatmapAnnotation(
    df = column_df,
    col = list(Group = Group_color)
  )
  complexheatmap <- ComplexHeatmap::Heatmap(
    regulonActivity_byGroup_Scaled,
    # col = col_fun,
    col = gplots::colorpanel(75, low = "blue", mid = "white", high = "red"),
    name = "heatmap",
    heatmap_legend_param = list(legend_direction = "horizontal",
                                legend_width = grid::unit(10, "cm"),
                                title_position = "lefttop"),
    column_names_side = "top", cluster_columns = TRUE,
    cluster_rows = TRUE, show_column_names = FALSE,
    row_names_gp = grid::gpar(fontsize = 12),
    top_annotation = column_annotation
  )
  # NOTE(review): draw() is called before sava_pdf() opens its devices, so it
  # presumably also renders once on whatever device is active - confirm.
  p <- ComplexHeatmap::draw(complexheatmap, heatmap_legend_side = "bottom")

  # Page height depends on the number of regulons in the input matrix.
  height <- if (nrow(auc_matrix) < 100) 15 else 30
  sava_pdf(obj = p, file = "SCENIC_AUC_group_heatmap.pdf", width = 20, height = height)

  # Export the unscaled group means with the regulon name as first column.
  regulon_names <- rownames(regulonActivity_byGroup)
  if (is.null(regulon_names)) {
    regulon_names <- as.character(seq_len(nrow(regulonActivity_byGroup)))
  }
  group_table <- data.frame(Regulon = regulon_names,
                            regulonActivity_byGroup,
                            check.names = FALSE,
                            stringsAsFactors = FALSE,
                            row.names = NULL)
  write.table(x = group_table, file = "finalout_auc_group.xls", quote = FALSE,
              sep = "\t", row.names = FALSE, col.names = TRUE)
}